/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "execute_graph_builder.h"

#include <utility>

namespace ge {
namespace {
// Classifies the anchor-index pair passed to ExecuteGraphBuilder::AddEdge.
//   (-1, -1)        -> control edge: returns true and sets is_data_edge = false.
//   (>= 0, >= 0)    -> data edge:    returns true and sets is_data_edge = true.
//   anything else   -> invalid:      returns false and leaves is_data_edge untouched.
bool ValidEdgeIndexes(int src_index, int dst_index, bool &is_data_edge) {
  const bool control_pair = (src_index == -1) && (dst_index == -1);
  const bool data_pair = (src_index >= 0) && (dst_index >= 0);
  if (!control_pair && !data_pair) {
    return false;
  }
  // The two cases are mutually exclusive, so data_pair alone identifies the edge kind.
  is_data_edge = data_pair;
  return true;
}
}  // namespace

/**
 * Interns `str` (including its terminating NUL) into the builder's string pool.
 *
 * Deduplication goes through `strs_to_offset_`. Whether that map compares by content or by
 * pointer depends on its declared key type in the header (NOTE(review): confirm).
 *
 * @param str    NUL-terminated string to add; nullptr is rejected.
 * @param exists Always assigned: true if `str` was already in `strs_to_offset_`, false otherwise
 *               (including the nullptr case).
 * @return Offset of the string in the pool, or kInvalidOffset when `str` is nullptr.
 */
PoolOffset ExecuteGraphBuilder::AddString(const char *str, bool &exists) {
  // Assign the out-parameter up front so no return path leaves it indeterminate.
  exists = false;
  if (str == nullptr) {
    return kInvalidOffset;
  }
  const auto iter = strs_to_offset_.find(str);
  if (iter != strs_to_offset_.end()) {
    exists = true;
    return iter->second;
  }
  // +1 so the terminating NUL is stored along with the characters.
  const auto offset = str_pool_.AddBuffer(str, strlen(str) + 1);
  strs_to_offset_[str] = offset;
  return offset;
}
// Convenience overload of AddString for callers that do not need to know whether `str`
// was already pooled. Returns the pool offset, or kInvalidOffset for a nullptr `str`.
PoolOffset ExecuteGraphBuilder::AddString(const char *str) {
  bool already_pooled = false;
  return AddString(str, already_pooled);
}
/**
 * Returns the offset of the OpDef builder for `op_type`, creating one on first use.
 *
 * @param op_type Operator type name; nullptr is rejected.
 * @return Index into op_def_pool_ for this type, or kInvalidOffset when `op_type` is nullptr.
 */
PoolOffset ExecuteGraphBuilder::AddOpDef(const char *op_type) {
  if (op_type == nullptr) {
    return kInvalidOffset;
  }

  const auto known = types_to_op_def_offset_.find(op_type);
  if (known != types_to_op_def_offset_.end()) {
    return known->second;
  }

  // New type: its slot is the next position appended to op_def_pool_.
  const auto new_offset = op_def_pool_.size();
  types_to_op_def_offset_[op_type] = new_offset;
  op_def_pool_.emplace_back(this);
  op_def_pool_[new_offset].SetType(op_type);
  return new_offset;
}

/**
 * Appends a new, empty node builder whose offset is its index in nodes_.
 *
 * @return The new node's offset. If the node builder cannot be allocated, returns
 *         kInvalidOffset and records FAILED in error_code_, so that a later Build() returns
 *         nullptr instead of dereferencing a null builder.
 */
PoolOffset ExecuteGraphBuilder::AddNode() {
  const auto offset = nodes_.size();
  // Take ownership immediately so the allocation cannot leak if emplace_back throws.
  std::unique_ptr<ExecuteNodeBuilder> node(new (std::nothrow) ExecuteNodeBuilder(this, offset));
  if (node == nullptr) {
    error_code_ = FAILED;
    return kInvalidOffset;
  }
  nodes_.emplace_back(std::move(node));
  return offset;
}
/**
 * Records an edge from (src_node, src_index) to (dst_node, dst_index).
 *
 * Both indexes -1 means a control edge; both >= 0 means a data edge. Any other index pair, or
 * a node offset outside nodes_, sets error_code_ to FAILED. An edge that is already present
 * (emplace reports no insertion) is ignored without raising an error and is not counted again.
 *
 * @return *this, to allow chaining.
 */
ExecuteGraphBuilder &ExecuteGraphBuilder::AddEdge(PoolOffset src_node, int src_index, PoolOffset dst_node, int dst_index) {
  bool data_edge = false;
  const bool indexes_ok = ValidEdgeIndexes(src_index, dst_index, data_edge);
  const bool nodes_ok = (src_node < nodes_.size()) && (dst_node < nodes_.size());
  if (!indexes_ok || !nodes_ok) {
    error_code_ = FAILED;
    return *this;
  }

  // Edges are grouped by their destination node; only newly inserted edges bump the counter.
  const bool inserted = dst_nodes_to_edges_[dst_node].emplace(this, src_node, dst_node, src_index, dst_index).second;
  if (inserted) {
    ++edge_num_;
  }
  return *this;
}
// Adds a control edge src_node -> dst_node. Control edges are encoded as an AddEdge call
// with both anchor indexes set to -1. Returns *this for chaining.
ExecuteGraphBuilder &ExecuteGraphBuilder::AddControlEdge(PoolOffset src_node, PoolOffset dst_node) {
  constexpr int kControlAnchorIndex = -1;
  return AddEdge(src_node, kControlAnchorIndex, dst_node, kControlAnchorIndex);
}
std::unique_ptr<ExecuteGraph> ExecuteGraphBuilder::Build() {
  if (error_code_ != SUCCESS) {
    return nullptr;
  }

  auto ret = BuildCheck();
  if (ret != SUCCESS) {
    return nullptr;
  }

  auto graph = std::unique_ptr<ExecuteGraph>(new (std::nothrow) ExecuteGraph());
  if (graph == nullptr) {
    return nullptr;
  }

  ret = BuildPools(graph.get());
  if (ret != SUCCESS) {
    return nullptr;
  }
  ret = BuildGraphAfterPool(graph.get());
  if (ret != SUCCESS) {
    return nullptr;
  }
  return graph;
}

/**
 * Fills graph's op-def, node and edge pools from the builders, slot by slot.
 *
 * Expects graph->op_def_pool_, node_pool_ and edge_pool_ to be pre-sized (BuildPools does this).
 * As a side effect, each node builder is assigned its id, equal to its index in nodes_.
 * Edges are written in dst_nodes_to_edges_ iteration order.
 *
 * @return SUCCESS, or the first non-SUCCESS status returned by an element's Build/BuildNode.
 */
Status ExecuteGraphBuilder::BuildGraphAfterPool(ExecuteGraph *graph) {
  GraphPoolReader pool_reader{graph};

  for (size_t def_idx = 0; def_idx < op_def_pool_.size(); ++def_idx) {
    const auto status = op_def_pool_[def_idx].Build(&pool_reader, &graph->op_def_pool_[def_idx]);
    if (status != SUCCESS) {
      return status;
    }
  }

  for (size_t node_idx = 0; node_idx < nodes_.size(); ++node_idx) {
    auto &node = nodes_[node_idx];
    node->SetId(static_cast<int64_t>(node_idx));
    const auto status = node->BuildNode(&pool_reader, &graph->node_pool_[node_idx]);
    if (status != SUCCESS) {
      return status;
    }
  }

  size_t edge_slot = 0;
  for (auto &entry : dst_nodes_to_edges_) {
    for (auto &edge : entry.second) {
      const auto status = edge.Build(&pool_reader, &graph->edge_pool_[edge_slot]);
      ++edge_slot;
      if (status != SUCCESS) {
        return status;
      }
    }
  }
  return SUCCESS;
}
/**
 * Sizes graph's op-def/node/edge pools, assigns tensor descs, then hands over the string pool.
 *
 * @return SUCCESS, or the failure status of BuildTensorDescPool. On that failure the builder's
 *         own string pool is left in place.
 */
Status ExecuteGraphBuilder::BuildPools(ExecuteGraph *graph) {
  // Pre-size so BuildGraphAfterPool can write each element into its slot by index.
  graph->op_def_pool_.resize(op_def_pool_.size());
  graph->node_pool_.resize(nodes_.size());
  graph->edge_pool_.resize(edge_num_);

  const auto ret = BuildTensorDescPool(graph);
  if (ret != SUCCESS) {
    return ret;
  }
  // Swap last: it takes the builder's string pool away, so doing it only after the
  // fallible step avoids leaving a failed builder with its strings gone.
  graph->str_pool_.Swap(str_pool_);
  return SUCCESS;
}
/**
 * Assigns a tensor-desc slot to every node input and output and sizes graph->td_pool_.
 *
 * An input fed by a data edge reuses the slot of the producing output. An unconnected input
 * gets a fresh slot. Each output gets a fresh slot.
 *
 * @return FAILED if a node builder is null, if two data edges feed the same input, or if a
 *         producer's output slot has not been assigned yet (nodes_ not in topological order);
 *         SUCCESS otherwise.
 */
Status ExecuteGraphBuilder::BuildTensorDescPool(ExecuteGraph *graph) {
  // Identifies an output tensor by (producer node offset, output index).
  using TensorId = std::pair<PoolOffset, int>;

  PoolOffset free_td = 0;
  std::map<TensorId, PoolOffset> out_tensors_to_td;
  // Per-node scratch buffer, indexed by input index; reused across iterations.
  std::vector<const EdgeBuilder *> data_edges;

  // nodes_ must already be in topological order. Otherwise an input that really has a data
  // edge may not find the tensor desc of its peer output, because it is not assigned yet.
  for (auto &node_builder : nodes_) {
    if (node_builder == nullptr) {
      return FAILED;
    }
    const auto input_num = node_builder->GetInputNum();
    data_edges.assign(input_num, nullptr);

    const auto edges_iter = dst_nodes_to_edges_.find(node_builder->GetOffset());
    if (edges_iter != dst_nodes_to_edges_.end()) {
      for (auto &edge : edges_iter->second) {
        const auto dst_index = edge.GetDstIndex();
        if (dst_index < 0) {
          continue;  // control edge: carries no tensor
        }
        // dst_index range was validated in BuildCheck.
        if (data_edges[dst_index] != nullptr) {
          // Two producers for one input; picking either would be arbitrary.
          return FAILED;
        }
        data_edges[dst_index] = &edge;
      }
    }

    for (size_t i = 0; i < input_num; ++i) {
      PoolOffset td;
      const EdgeBuilder *in_edge = data_edges[i];
      if (in_edge != nullptr) {
        auto *src_node = nodes_[in_edge->GetSrcNode()].get();
        if (src_node == nullptr) {
          return FAILED;
        }
        const auto tensors_iter = out_tensors_to_td.find(TensorId(src_node->GetOffset(), in_edge->GetSrcIndex()));
        if (tensors_iter == out_tensors_to_td.end()) {
          // TODO: an edge exists but its producer's output TensorDesc has not been assigned yet,
          //       which likely means the topological order is wrong.
          //       Loops in v1 control flow are not handled here.
          return FAILED;
        }
        td = tensors_iter->second;
      } else {
        // Unconnected input: it gets its own tensor desc.
        td = free_td++;
      }
      node_builder->SetInputDesc(static_cast<int>(i), td);
    }

    // TODO: reference outputs are not handled yet, and neither are the many other cases where the
    //       same tensor appears at different positions. The accurate approach is to resolve
    //       equivalent tensors via the symbol table and assign them once.
    const auto self_offset = node_builder->GetOffset();
    for (size_t i = 0; i < node_builder->GetOutputNum(); ++i) {
      const auto td = free_td++;
      node_builder->SetOutputDesc(static_cast<int>(i), td);
      out_tensors_to_td[TensorId(self_offset, static_cast<int>(i))] = td;
    }
  }
  graph->td_pool_.resize(free_td);
  return SUCCESS;
}
// Returns the node builder stored at `offset`, or nullptr when `offset` is outside nodes_.
// The returned pointer is non-owning; nodes_ keeps ownership.
ExecuteNodeBuilder *ExecuteGraphBuilder::GetNodeBuilder(PoolOffset offset) {
  return (offset < nodes_.size()) ? nodes_[offset].get() : nullptr;
}
// Returns a pointer to the OpDef builder at `offset`, or nullptr when `offset` is outside
// op_def_pool_. The pointer refers into op_def_pool_ and is non-owning.
OpDefBuilder *ExecuteGraphBuilder::GetOpDefBuilder(PoolOffset offset) {
  return (offset < op_def_pool_.size()) ? &op_def_pool_[offset] : nullptr;
}
/**
 * Structural validation run by Build() before anything is materialised.
 *
 * Checks that every node slot holds a builder whose offset is within nodes_, and that every
 * edge references existing nodes. For data edges (index >= 0), it also checks that the
 * anchor indexes are below the producer's output count and the consumer's input count.
 *
 * @return SUCCESS if all checks pass, FAILED otherwise.
 */
Status ExecuteGraphBuilder::BuildCheck() const {
  // TODO: check that the name on each Node matches its OpDef.
  // TODO: check that every other referenced offset exists.

  // Validate node slots first so the edge checks below can safely dereference them.
  for (const auto &node : nodes_) {
    if (node == nullptr || node->GetOffset() >= nodes_.size()) {
      return FAILED;
    }
  }

  for (const auto &dst_node_to_edges : dst_nodes_to_edges_) {
    for (const auto &edge : dst_node_to_edges.second) {
      if (edge.GetSrcNode() >= nodes_.size() || edge.GetDstNode() >= nodes_.size()) {
        return FAILED;
      }
      // Control edges use -1 and skip the anchor-range checks. Non-negative indexes are cast
      // explicitly to avoid an implicit signed/unsigned comparison.
      const auto src_index = edge.GetSrcIndex();
      if (src_index >= 0 && static_cast<size_t>(src_index) >= nodes_[edge.GetSrcNode()]->GetOutputNum()) {
        return FAILED;
      }
      const auto dst_index = edge.GetDstIndex();
      if (dst_index >= 0 && static_cast<size_t>(dst_index) >= nodes_[edge.GetDstNode()]->GetInputNum()) {
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
// Destructor, defaulted but defined out of line in this .cpp file. Presumably this keeps
// destruction of members whose types are only forward-declared in the header inside this
// translation unit. NOTE(review): confirm against execute_graph_builder.h.
ExecuteGraphBuilder::~ExecuteGraphBuilder() = default;
}  // namespace ge